/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/accessor.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/script/printer/ir_docsifier.h>

#include <sstream>

#include "./utils.h"

namespace tvm {
namespace script {
namespace printer {

// Registers the reflection information of the two node types used by this
// file: FrameNode and IRDocsifierNode.
// NOTE(review): presumably the macro runs this body once during static
// initialization of the translation unit -- confirm against the macro definition.
TVM_FFI_STATIC_INIT_BLOCK() {
  FrameNode::RegisterReflection();
  IRDocsifierNode::RegisterReflection();
}

// Binds `obj` to a freshly generated identifier and returns an IdDoc for it.
//
// \param obj       The object being defined.
// \param frame     The frame whose exit callback removes the binding again.
// \param name_hint Preferred name; it is made unique against `defined_names`.
// \return An IdDoc carrying the name bound to `obj`.
IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame,
                              const ffi::String& name_hint) {
  const auto existing = obj2info.find(obj);
  if (existing != obj2info.end()) {
    // A second definition of the same variable is not permitted by TVM's IR
    // dialects inside one IRModule, so this path is only taken for ill-formed
    // input.
    //
    // Unlike most utilities, the printer can neither assume well-formed input
    // nor raise an error on ill-formed input: it is frequently the tool used
    // to debug such input, and throwing here would hide the very IRModule the
    // developer is trying to inspect.  Reuse the name that is already bound.
    return IdDoc(existing->second.name.value());
  }

  ffi::String candidate = name_hint;
  if (cfg->show_object_address) {
    // Append the object's address to the hint when the config asks for it.
    std::stringstream with_address;
    with_address << candidate << "_" << obj.get();
    candidate = with_address.str();
  }

  // Reserve a name that does not clash with anything currently defined.
  const ffi::String unique_name = GenerateUniqueName(candidate, this->defined_names);
  this->defined_names.insert(unique_name);

  DocCreator make_id = [unique_name]() { return IdDoc(unique_name); };
  obj2info.insert({obj, VariableInfo{std::move(make_id), unique_name}});

  // The binding is dropped again when `frame` runs its exit callbacks.
  frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
  return IdDoc(unique_name);
}

void IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, DocCreator doc_factory) {
  ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj;
  obj2info.insert({obj, VariableInfo{std::move(doc_factory), std::nullopt}});
  frame->AddExitCallback([this, obj]() { this->RemoveVar(obj); });
}

// Looks up the doc for a previously defined object.
//
// \param obj The object to look up.
// \return The doc produced by the object's registered creator, or std::nullopt
//         when `obj` has no entry in `obj2info`.
ffi::Optional<ExprDoc> IRDocsifierNode::GetVarDoc(const ObjectRef& obj) const {
  if (auto entry = obj2info.find(obj); entry != obj2info.end()) {
    return entry->second.creator();
  }
  return std::nullopt;
}

// Stores `obj` in the metadata table (unless an equal entry already exists) and
// returns the expression `metadata[<type key>][<index>]` that refers to it.
//
// \param obj The value to store; it must not be nullptr.
// \return An ExprDoc indexing into `metadata` by type key and position.
ExprDoc IRDocsifierNode::AddMetadata(const ffi::Any& obj) {
  ICHECK(obj != nullptr) << "TypeError: Cannot add nullptr to metadata";
  ffi::String key = obj.GetTypeKey();
  ffi::Array<ffi::Any>& bucket = metadata[key];

  // Reuse the slot of an entry that compares equal under ffi::AnyEqual.
  const auto equals_obj = [&](const ffi::Any& candidate) {
    return ffi::AnyEqual()(candidate, obj);
  };
  const auto position = std::find_if(bucket.begin(), bucket.end(), equals_obj);
  int index = position - bucket.begin();
  if (index == static_cast<int>(bucket.size())) {
    // Not found: append, so `index` now refers to the new last element.
    bucket.push_back(obj);
  }

  ExprDoc per_type = IdDoc("metadata")[{LiteralDoc::Str(key, std::nullopt)}];
  return per_type[{LiteralDoc::Int(index, std::nullopt)}];
}

// Appends `ginfo` to the list of global infos stored under `name`.
//
// \param name  Key of the list inside `global_infos`.
// \param ginfo The global info to append; it must be defined.
void IRDocsifierNode::AddGlobalInfo(const ffi::String& name, const GlobalInfo& ginfo) {
  ICHECK(ginfo.defined()) << "TypeError: Cannot add nullptr to global_infos";
  global_infos[name].push_back(ginfo);
}

// Returns true when `obj` currently has an entry in `obj2info`.
bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const {
  return obj2info.find(obj) != obj2info.end();
}

// Drops the definition of `obj`, releasing its name (if it has one) so that the
// name can be handed out again.
//
// \param obj The object to undefine; it must currently be defined.
void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
  auto it = obj2info.find(obj);
  ICHECK(it != obj2info.end()) << "No such object: " << obj;
  // Definitions made through a doc factory carry no name to release.
  auto& bound_name = it->second.name;
  if (bound_name.has_value()) {
    defined_names.erase(bound_name.value());
  }
  obj2info.erase(it);
}

void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root,
                                      ffi::TypedFunction<bool(ObjectRef)> is_var) {
  class Visitor {
   public:
    void operator()(ObjectRef obj) { this->VisitObjectRef(obj); }

   private:
    void RecursiveVisitAny(ffi::Any* value) {
      if (std::optional<ObjectRef> opt = value->as<ObjectRef>()) {
        this->VisitObjectRef(*opt);
      }
    }
    void VisitObjectRef(ObjectRef obj) {
      if (!obj.defined()) {
        return;
      }
      if (visited_.count(obj.get())) {
        if (is_var(obj)) {
          HandleVar(obj.get());
        }
        return;
      }
      visited_.insert(obj.get());
      stack_.push_back(obj.get());
      if (obj->IsInstance<ffi::ArrayObj>()) {
        const ffi::ArrayObj* array = static_cast<const ffi::ArrayObj*>(obj.get());
        for (Any element : *array) {
          this->RecursiveVisitAny(&element);
        }
      } else if (obj->IsInstance<ffi::MapObj>()) {
        const ffi::MapObj* map = static_cast<const ffi::MapObj*>(obj.get());
        for (std::pair<Any, Any> kv : *map) {
          this->RecursiveVisitAny(&kv.first);
          this->RecursiveVisitAny(&kv.second);
        }
      } else {
        const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index());
        if (tinfo->metadata != nullptr) {
          ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) {
            Any field_value = ffi::reflection::FieldGetter(field_info)(obj);
            this->RecursiveVisitAny(&field_value);
          });
        }
      }
      if (is_var(obj)) {
        HandleVar(obj.get());
      }
      stack_.pop_back();
    }

    void HandleVar(const Object* var) {
      if (common_prefix.count(var) == 0) {
        common_prefix[var] = stack_;
        return;
      }
      std::vector<const Object*>& a = common_prefix[var];
      std::vector<const Object*>& b = stack_;
      int n = std::min(a.size(), b.size());
      for (int i = 0; i < n; ++i) {
        if (a[i] != b[i]) {
          a.resize(i);
          break;
        }
      }
    }

    std::vector<const Object*> stack_;
    std::unordered_set<const Object*> visited_;

   public:
    ffi::TypedFunction<bool(ObjectRef)> is_var;
    std::unordered_map<const Object*, std::vector<const Object*>> common_prefix;
  };
  Visitor visitor;
  visitor.is_var = is_var;
  visitor(root);
  this->common_prefix = std::move(visitor.common_prefix);
}

// Creates a docsifier configured by `cfg`.
//
// \param cfg The printer configuration stored on the new node.
IRDocsifier::IRDocsifier(const PrinterConfig& cfg) {
  auto node = ffi::make_object<IRDocsifierNode>();
  node->cfg = cfg;
  // The dispatch token list starts with a single empty token.
  node->dispatch_tokens.push_back("");
  // Mark the builtin keywords reported by cfg as already-defined names.
  const auto builtin_keywords = cfg->GetBuiltinKeywords();
  for (const ffi::String& keyword : builtin_keywords) {
    node->defined_names.insert(keyword);
  }
  data_ = std::move(node);
}

// Returns the single dispatch table shared by all IRDocsifier instances.
// It is a function-local static, so it is constructed on first use.
IRDocsifier::FType& IRDocsifier::vtable() {
  static IRDocsifier::FType dispatch_table;
  return dispatch_table;
}

// Fallback handler of the dispatch table: the object is stored via AddMetadata
// and the resulting metadata reference is returned as its doc.
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
    .set_fallback([](ObjectRef node, AccessPath path, IRDocsifier docsifier) -> Doc {
      return docsifier->AddMetadata(node);
    });

}  // namespace printer
}  // namespace script
}  // namespace tvm
